
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch.nn import CrossEntropyLoss
import math

from ..deberta import *
from ..utils import *

class SequenceClassificationModel(NNModule):
  """Model for sequence classification (or regression).
  This module is composed of the DeBERTa encoder with a context pooler, dropout and
  a linear layer on top of the last encoder layer.

  Params:
    `config`: a ModelConfig class instance with the configuration to build a new model.
    `num_labels`: the number of classes for the classifier. Default = 2.
      When `num_labels` is 1 the model is trained as a regression model with MSE loss.
    `drop_out`: dropout probability applied to the pooled output. When `None`,
      `config.hidden_dropout_prob` is used.
    `pre_trained`: forwarded to `DeBERTa`. When it is not `None`, the configuration
      exposed by the encoder (`self.bert.config`) is used instead of `config`.

  Inputs:
    `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
      with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
      `extract_features.py`, `run_classifier.py` and `run_squad.py`)
    `type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
      types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
      a `sentence B` token (see BERT paper for more details).
    `input_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
      selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
      input sequence length in the current batch. It's the mask that we typically use for attention when
      a batch has varying length sentences.
    `labels`: optional labels for the output, one of
      - a tensor of shape [batch_size] or [batch_size, 1] with class indices selected in
        [0, ..., num_labels-1]; examples with a negative label are excluded from the loss,
      - a tensor of shape [batch_size, num_labels] with per-class target weights (soft labels),
      - regression targets with batch_size elements when `num_labels` is 1.
    `position_ids`: optional position ids forwarded to the encoder.

  Outputs:
    A tuple `(logits, loss)`.
    `logits` are the classification logits of shape [batch_size, num_labels]
      (flattened to [batch_size] in the regression case when `labels` is given).
    `loss` is the int `0` if `labels` is `None`, otherwise the training loss: MSE for regression,
      CrossEntropy for class indices, or the mean soft-label cross entropy for soft labels.

  Example usage:
  ```python
  # Already been converted into WordPiece token ids
  input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

  config = ModelConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

  num_labels = 2

  model = SequenceClassificationModel(config, num_labels)
  logits, loss = model(input_ids, token_type_ids, input_mask)
  ```
  """
  def __init__(self, config, num_labels=2, drop_out=None, pre_trained=None):
    super().__init__(config)
    self.num_labels = num_labels
    self.bert = DeBERTa(config, pre_trained=pre_trained)
    # With a pre-trained encoder, use the configuration the encoder exposes.
    if pre_trained is not None:
      self.config = self.bert.config
    else:
      self.config = config
    pool_config = PoolConfig(self.config)
    self.pooler = ContextPooler(pool_config)
    # The classifier input size is whatever the pooler produces.
    output_dim = self.pooler.output_dim()

    self.classifier = nn.Linear(output_dim, num_labels)
    drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
    self.dropout = StableDropout(drop_out)
    self.apply(self.init_weights)
    # NOTE(review): presumably restores the pre-trained encoder weights after the
    # random initialization above -- confirm against DeBERTa.apply_state.
    self.bert.apply_state()

  def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
    """Run the encoder and classifier; return `(logits, loss)`.

    Extra keyword arguments are accepted and ignored. See the class docstring for
    the meaning of the inputs and the different `labels` layouts.
    """
    encoder_layers = self.bert(input_ids, type_ids, input_mask, position_ids=position_ids, output_all_encoded_layers=True)
    # Only the last encoder layer feeds the pooler.
    pooled_output = self.pooler(encoder_layers[-1])
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

    loss = 0
    if labels is not None:
      if self.num_labels ==1:
        # regression task
        loss_fn = nn.MSELoss()
        logits=logits.view(-1).to(labels.dtype)
        loss = loss_fn(logits, labels.view(-1))
      elif labels.dim()==1 or labels.size(-1)==1:
        # Hard labels. Flatten first so that [batch_size, 1] labels yield one row
        # index per example; nonzero() on a 2-D tensor would return coordinates.
        labels = labels.reshape(-1)
        # Negative labels mark examples that must not contribute to the loss.
        label_index = (labels >= 0).nonzero().view(-1)
        labels = labels.long()
        if label_index.size(0) > 0:
          labeled_logits = logits.index_select(0, label_index)
          labels = labels.index_select(0, label_index)
          loss_fct = CrossEntropyLoss()
          loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels)
        else:
          # No labeled example in the batch: constant zero on logits' device/dtype.
          loss = torch.tensor(0).to(logits)
      else:
        # Soft labels: cross entropy against the given per-class weights.
        log_softmax = torch.nn.LogSoftmax(-1)
        loss = -(log_softmax(logits)*labels).sum(-1).mean()

    return (logits,loss)
